# experiment2 = baseline + part + music
# The backbone is still the one trained in experiment1.
# Evaluation: 5-way 5-shot 15-query, 600 tasks; mean accuracy: 0.8228

from lib.algorithms import ALGORITHMS
from .experiment4_net import Experiment2
from .experiment4_loss import Experiment2Loss
from .experiment4_metric import Experiment2Acc
from lib.sampler import NormalSampler
from lib.sampler import EpisodeSampler


@ALGORITHMS.register('experiment2')
def expirement2(cfg):
    """Assemble the components of the 'experiment2' algorithm.

    Registered in ``ALGORITHMS`` under the key ``'experiment2'``.

    NOTE: the function name is misspelled ("expirement2"); it is kept
    unchanged because it may be referenced by name elsewhere.
    NOTE(review): the Experiment2* classes used here are imported from the
    ``experiment4_*`` modules; confirm that this is intended.

    Args:
        cfg: Configuration object, passed unchanged to the network, loss
            and metric constructors.

    Returns:
        dict: ``'net'``, ``'loss'`` and ``'metric'`` map to newly
        constructed instances. ``'train_sampler'``, ``'validate_sampler'``
        and ``'test_sampler'`` map to sampler *classes* (not instances).
    """
    # Construct the components in the original order: net, loss, metric.
    network = Experiment2(cfg)
    criterion = Experiment2Loss(cfg)
    accuracy = Experiment2Acc(cfg)

    # Training and validation share NormalSampler; testing uses EpisodeSampler.
    components = dict(
        net=network,
        loss=criterion,
        metric=accuracy,
        train_sampler=NormalSampler,
        validate_sampler=NormalSampler,
        test_sampler=EpisodeSampler,
    )
    return components
